{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Class distance weighted cross-entropy loss for ordinal regression and deep learning -- cement strength dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of a method for ordinal regression by Polat et al 2022 [1].\n",
    "\n",
    "**Paper reference:**\n",
    "\n",
    "- [1] G Polat, I Ergenc, HT Kani, YO Alahdab, O Atug, A Temizel. \"[Class Distance Weighted Cross-Entropy Loss for Ulcerative Colitis Severity Estimation](https://arxiv.org/abs/2202.05167).\" arXiv preprint arXiv:1612.00775 (2022)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0 -- Obtaining and preparing the cement_strength dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using the cement_strength dataset from [https://github.com/gagolews/ordinal_regression_data/blob/master/cement_strength.csv](https://github.com/gagolews/ordinal_regression_data/blob/master/cement_strength.csv).\n",
    "\n",
    "First, we are going to download and prepare the and save it as CSV files locally. This is a general procedure that is not specific to CORN.\n",
    "\n",
    "This dataset has 5 ordinal labels (1, 2, 3, 4, and 5). Note that we require labels to be starting at 0, which is why we subtract \"1\" from the label column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of features: 8\n",
      "Number of examples: 998\n",
      "Labels: [0 1 2 3 4]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "data_df = pd.read_csv(\"https://raw.githubusercontent.com/gagolews/ordinal_regression_data/master/cement_strength.csv\")\n",
    " \n",
    "data_df[\"response\"] = data_df[\"response\"]-1 # labels should start at 0\n",
    "\n",
    "data_labels = data_df[\"response\"]\n",
    "data_features = data_df.loc[:, [\"V1\", \"V2\", \"V3\", \"V4\", \"V5\", \"V6\", \"V7\", \"V8\"]]\n",
    "\n",
    "print('Number of features:', data_features.shape[1])\n",
    "print('Number of examples:', data_features.shape[0])\n",
    "print('Labels:', np.unique(data_labels.values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Split into training and test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sebastianraschka/miniforge3/lib/python3.9/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.1\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    data_features.values,\n",
    "    data_labels.values,\n",
    "    test_size=0.2,\n",
    "    random_state=1,\n",
    "    stratify=data_labels.values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Standardize features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "sc = StandardScaler()\n",
    "X_train_std = sc.fit_transform(X_train)\n",
    "X_test_std = sc.transform(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 -- Setting up the dataset and dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we set up the data set and data loaders. This is a general procedure that is not specific to CORN. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training on cpu\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.0001\n",
    "num_epochs = 50\n",
    "batch_size = 128\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 5\n",
    "\n",
    "# Other\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print('Training on', DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "\n",
    "    def __init__(self, feature_array, label_array, dtype=np.float32):\n",
    "    \n",
    "        self.features = feature_array.astype(np.float32)\n",
    "        self.labels = label_array\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        inputs = self.features[index]\n",
    "        label = self.labels[index]\n",
    "        return inputs, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.labels.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input batch dimensions: torch.Size([128, 8])\n",
      "Input label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = MyDataset(X_train_std, y_train)\n",
    "test_dataset = MyDataset(X_test_std, y_test)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=batch_size,\n",
    "                          shuffle=True, # want to shuffle the dataset\n",
    "                          num_workers=0) # number processes/CPUs to use\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=batch_size,\n",
    "                         shuffle=False,\n",
    "                         num_workers=0)\n",
    "\n",
    "# Checking the dataset\n",
    "for inputs, labels in train_loader:  \n",
    "    print('Input batch dimensions:', inputs.shape)\n",
    "    print('Input label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 - Implementing the class distance weighted cross-entropy loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "According to the paper, the loss is described as follows:\n",
    "\n",
    "$$\\mathbf{C D WC E}=-\\sum_{i=0}^{N-1} \\log (1-\\hat{y}) \\times|i-c|^{\\text {power }}$$\n",
    "\n",
    "where\n",
    "\n",
    "- $N$: the number of class labels\n",
    "- $\\hat{y}$: the predicted scores\n",
    "- $c$: ground-truth class\n",
    "- power: a hyperparameter term that determines the strength of the cost coefficient"
   ]
  },
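  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand-worked illustration of the formula (a sketch, not taken from the paper), write $\\hat{y}_i$ for the predicted probability of class $i$ and consider $N=3$ classes, ground-truth class $c=0$, power $=5$, and probabilities $(0.3792, 0.3104, 0.3104)$ rounded to four decimals. The term for $i=0$ vanishes because $|0-0|^5=0$, so only the two wrong classes contribute:\n",
    "\n",
    "$$-\\log(1-0.3104)\\cdot 1^5-\\log(1-0.3104)\\cdot 2^5 \\approx 0.3716 \\cdot (1+32) \\approx 12.26$$\n",
    "\n",
    "The class two steps away from the ground truth is weighted 32 times more heavily than the adjacent class, which is how the loss penalizes predictions far from the true label. Up to rounding of the probabilities, this matches the first entry (12.2654) returned by the implementations below."
   ]
  },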
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3792, 0.3104, 0.3104],\n",
       "        [0.3072, 0.4147, 0.2780],\n",
       "        [0.4263, 0.2248, 0.3490],\n",
       "        [0.2668, 0.2978, 0.4354]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "targets = torch.tensor([0, 2, 1, 2])\n",
    "\n",
    "logits = torch.tensor( [[-0.3,  -0.5, -0.5], # each row is 1 training example\n",
    "                        [-0.4,  -0.1, -0.5],\n",
    "                        [-0.3,  -0.94, -0.5],\n",
    "                        [-0.99, -0.88, -0.5]])\n",
    "\n",
    "probas = F.softmax(logits, dim=1)\n",
    "probas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([12.2654, 12.2824,  0.9848, 10.2828])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cdw_ce_loss_naive1(probas, targets, power=5):\n",
    "    \n",
    "    loss = torch.zeros(probas.shape[0])\n",
    "    for example in range(probas.shape[0]):\n",
    "        for i in range(probas.shape[1]):\n",
    "            loss[example] += -torch.log(1-probas[example, i]) * torch.abs(i - targets[example])**power\n",
    "    \n",
    "    return loss\n",
    "        \n",
    "cdw_ce_loss_naive1(probas, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "204 µs ± 3.06 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit cdw_ce_loss_naive1(probas, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([12.2654, 12.2824,  0.9848, 10.2828])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cdw_ce_loss_naive2(probas, targets, power=5):\n",
    "    \n",
    "    loss = 0.\n",
    "    for i in range(probas.shape[1]):\n",
    "        loss += (-torch.log(1-probas[:, i]) * torch.abs(i - targets)**power)\n",
    "        \n",
    "    return loss\n",
    "        \n",
    "cdw_ce_loss_naive2(probas, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "44.8 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit cdw_ce_loss_naive2(probas, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([12.2654, 12.2824,  0.9848, 10.2828])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cdw_ce_loss_naive3(probas, targets, power=5):\n",
    "    \n",
    "    labels = torch.arange(probas.shape[1]).repeat(probas.shape[0], 1)\n",
    "    loss = (-torch.log(1-probas) * torch.abs(labels - targets.reshape(probas.shape[0], 1))**power).sum(dim=1)\n",
    "        \n",
    "    return loss\n",
    "        \n",
    "cdw_ce_loss_naive3(probas, targets)"
   ]
  },
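  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to read the vectorized version: it builds the whole weight matrix $W_{j,i}=|i-c_j|^{\\text{power}}$ at once via broadcasting, where $c_j$ is the ground-truth class of example $j$ (the symbol $W$ is introduced here for illustration only). For the toy targets $(0, 2, 1, 2)$ with three classes and power $=5$, this gives\n",
    "\n",
    "$$W=\\begin{pmatrix} 0 & 1 & 32 \\\\ 32 & 1 & 0 \\\\ 1 & 0 & 1 \\\\ 32 & 1 & 0 \\end{pmatrix}$$\n",
    "\n",
    "The matrix is multiplied elementwise with $-\\log(1-\\hat{y})$ and summed over each row, which yields one loss value per example."
   ]
  },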
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19.3 µs ± 95.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit cdw_ce_loss_naive3(probas, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(8.9538)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cdw_ce_loss(logits, targets, power=5, reduction=\"mean\"):\n",
    "    \n",
    "    probas = torch.softmax(logits, dim=1)\n",
    "    labels = torch.arange(probas.shape[1]).repeat(probas.shape[0], 1)\n",
    "    loss = (-torch.log(1-probas) * torch.abs(labels - targets.reshape(probas.shape[0], 1))**power).sum(dim=1)\n",
    "        \n",
    "    if reduction == \"none\":\n",
    "        return loss\n",
    "    elif reduction == \"sum\":\n",
    "        return loss.sum()\n",
    "    elif reduction == \"mean\":\n",
    "        return loss.mean()    \n",
    "    else:\n",
    "        raise ValueError(\"reduction must be 'none', 'sum', or 'mean'\")    \n",
    "\n",
    "cdw_ce_loss(logits, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 - Implementing a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, in_features, num_classes, num_hidden_1=300, num_hidden_2=300):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.num_classes = num_classes\n",
    "        self.my_network = torch.nn.Sequential(\n",
    "            \n",
    "            # 1st hidden layer\n",
    "            torch.nn.Linear(in_features, num_hidden_1, bias=False),\n",
    "            torch.nn.LeakyReLU(),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.BatchNorm1d(num_hidden_1),\n",
    "            \n",
    "            # 2nd hidden layer\n",
    "            torch.nn.Linear(num_hidden_1, num_hidden_2, bias=False),\n",
    "            torch.nn.LeakyReLU(),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.BatchNorm1d(num_hidden_2),\n",
    "            \n",
    "            # Output layer\n",
    "            torch.nn.Linear(num_hidden_2, num_classes)\n",
    "        )\n",
    "                \n",
    "    def forward(self, x):\n",
    "        logits = self.my_network(x)\n",
    "        return logits\n",
    "    \n",
    "    \n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = MLP(in_features=8, num_classes=NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.005)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 - Using the CDWCE loss for model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/050 | Batch 000/007 | Loss: 134.0854\n",
      "Epoch: 002/050 | Batch 000/007 | Loss: 16.1060\n",
      "Epoch: 003/050 | Batch 000/007 | Loss: 16.2816\n",
      "Epoch: 004/050 | Batch 000/007 | Loss: 12.1848\n",
      "Epoch: 005/050 | Batch 000/007 | Loss: 9.2307\n",
      "Epoch: 006/050 | Batch 000/007 | Loss: 5.7477\n",
      "Epoch: 007/050 | Batch 000/007 | Loss: 4.4735\n",
      "Epoch: 008/050 | Batch 000/007 | Loss: 4.3764\n",
      "Epoch: 009/050 | Batch 000/007 | Loss: 4.5788\n",
      "Epoch: 010/050 | Batch 000/007 | Loss: 3.4303\n",
      "Epoch: 011/050 | Batch 000/007 | Loss: 4.2646\n",
      "Epoch: 012/050 | Batch 000/007 | Loss: 2.7244\n",
      "Epoch: 013/050 | Batch 000/007 | Loss: 3.5027\n",
      "Epoch: 014/050 | Batch 000/007 | Loss: 3.1227\n",
      "Epoch: 015/050 | Batch 000/007 | Loss: 1.9005\n",
      "Epoch: 016/050 | Batch 000/007 | Loss: 4.4430\n",
      "Epoch: 017/050 | Batch 000/007 | Loss: 2.2113\n",
      "Epoch: 018/050 | Batch 000/007 | Loss: 2.9496\n",
      "Epoch: 019/050 | Batch 000/007 | Loss: 2.4737\n",
      "Epoch: 020/050 | Batch 000/007 | Loss: 2.2458\n",
      "Epoch: 021/050 | Batch 000/007 | Loss: 2.2490\n",
      "Epoch: 022/050 | Batch 000/007 | Loss: 2.5528\n",
      "Epoch: 023/050 | Batch 000/007 | Loss: 3.3671\n",
      "Epoch: 024/050 | Batch 000/007 | Loss: 1.6624\n",
      "Epoch: 025/050 | Batch 000/007 | Loss: 1.5456\n",
      "Epoch: 026/050 | Batch 000/007 | Loss: 1.8861\n",
      "Epoch: 027/050 | Batch 000/007 | Loss: 1.5842\n",
      "Epoch: 028/050 | Batch 000/007 | Loss: 2.4168\n",
      "Epoch: 029/050 | Batch 000/007 | Loss: 1.6376\n",
      "Epoch: 030/050 | Batch 000/007 | Loss: 1.8073\n",
      "Epoch: 031/050 | Batch 000/007 | Loss: 2.5007\n",
      "Epoch: 032/050 | Batch 000/007 | Loss: 1.4211\n",
      "Epoch: 033/050 | Batch 000/007 | Loss: 1.9054\n",
      "Epoch: 034/050 | Batch 000/007 | Loss: 1.3790\n",
      "Epoch: 035/050 | Batch 000/007 | Loss: 2.0045\n",
      "Epoch: 036/050 | Batch 000/007 | Loss: 2.9551\n",
      "Epoch: 037/050 | Batch 000/007 | Loss: 1.3715\n",
      "Epoch: 038/050 | Batch 000/007 | Loss: 1.7470\n",
      "Epoch: 039/050 | Batch 000/007 | Loss: 1.7664\n",
      "Epoch: 040/050 | Batch 000/007 | Loss: 1.3839\n",
      "Epoch: 041/050 | Batch 000/007 | Loss: 1.1422\n",
      "Epoch: 042/050 | Batch 000/007 | Loss: 1.1026\n",
      "Epoch: 043/050 | Batch 000/007 | Loss: 1.4556\n",
      "Epoch: 044/050 | Batch 000/007 | Loss: 1.0135\n",
      "Epoch: 045/050 | Batch 000/007 | Loss: 1.7183\n",
      "Epoch: 046/050 | Batch 000/007 | Loss: 1.2620\n",
      "Epoch: 047/050 | Batch 000/007 | Loss: 1.5380\n",
      "Epoch: 048/050 | Batch 000/007 | Loss: 1.0275\n",
      "Epoch: 049/050 | Batch 000/007 | Loss: 1.3223\n",
      "Epoch: 050/050 | Batch 000/007 | Loss: 1.1917\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    model = model.train()\n",
    "    for batch_idx, (features, class_labels) in enumerate(train_loader):\n",
    "\n",
    "        class_labels = class_labels.to(DEVICE)\n",
    "        features = features.to(DEVICE)\n",
    "        logits = model(features)\n",
    "        \n",
    "        loss = cdw_ce_loss(logits, class_labels)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 200:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 -- Evaluate model\n",
    "\n",
    "Finally, after model training, we can evaluate the performance of the model. For example, via the mean absolute error and mean squared error measures.\n",
    "\n",
    "For this, we are going to use the `beckham_logits_to_labels` to convert the logits into ordinal class labels as shown below:\n"
   ]
  },
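  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the two metrics can be written as follows, where $n$ is the number of examples, $y_j$ is the ground-truth label, $z_{j,i}$ is the logit of class $i$ for example $j$, and $\\hat{c}_j$ is the predicted label (this notation is introduced here for illustration and does not come from the paper):\n",
    "\n",
    "$$\\hat{c}_j=\\arg\\max_i z_{j,i}, \\qquad \\text{MAE}=\\frac{1}{n}\\sum_{j=1}^{n}|\\hat{c}_j-y_j|, \\qquad \\text{MSE}=\\frac{1}{n}\\sum_{j=1}^{n}(\\hat{c}_j-y_j)^2$$\n",
    "\n",
    "Because the labels are integers, the MAE is the average number of classes by which a prediction misses the true label, while the MSE penalizes large misses more strongly."
   ]
  },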
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logits_to_labels(logits, model):\n",
    "    predictions = torch.argmax(logits, dim=1)\n",
    "    return predictions\n",
    "    \n",
    "\n",
    "def compute_mae_and_mse(model, data_loader, device):\n",
    "\n",
    "    with torch.inference_mode():\n",
    "    \n",
    "        mae, mse, acc, num_examples = 0., 0., 0., 0\n",
    "\n",
    "        for i, (features, targets) in enumerate(data_loader):\n",
    "\n",
    "            features = features.to(device)\n",
    "            targets = targets.float().to(device)\n",
    "\n",
    "            logits = model(features)\n",
    "            predicted_labels = logits_to_labels(logits, model)\n",
    "\n",
    "            num_examples += targets.size(0)\n",
    "            mae += torch.sum(torch.abs(predicted_labels - targets))\n",
    "            mse += torch.sum((predicted_labels - targets)**2)\n",
    "\n",
    "        mae = mae / num_examples\n",
    "        mse = mse / num_examples\n",
    "        return mae, mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mae, train_mse = compute_mae_and_mse(model, train_loader, DEVICE)\n",
    "test_mae, test_mse = compute_mae_and_mse(model, test_loader, DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean absolute error (train/test): 0.36 | 0.41\n",
      "Mean squared error (train/test): 0.37 | 0.46\n"
     ]
    }
   ],
   "source": [
    "print(f'Mean absolute error (train/test): {train_mae:.2f} | {test_mae:.2f}')\n",
    "print(f'Mean squared error (train/test): {train_mse:.2f} | {test_mse:.2f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
